{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "*Note: MXNet supplys an official implementation of deformable convolution, here tests this repo against the MXNet's implementation.*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import mxnet as mx\n",
    "from torch import nn\n",
    "from time import time\n",
    "from pprint import pprint\n",
    "from torch.autograd import Variable\n",
    "from mxnet.initializer import Initializer\n",
    "from deform_conv import DeformConv2D"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set up parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using gpu0\n"
     ]
    }
   ],
   "source": [
    "bs, inC, ouC, H, W = 1, 1, 1, 4, 5\n",
    "kH, kW = 3, 3\n",
    "padding = 1\n",
    "\n",
    "# ---------------------------------------\n",
    "use_gpu = torch.cuda.is_available()\n",
    "gpu_device = 0\n",
    "if use_gpu:\n",
    "    os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(gpu_device)\n",
    "    print(\"Using gpu{}\".format(os.getenv(\"CUDA_VISIBLE_DEVICES\")))\n",
    "# ---------------------------------------\n",
    "raw_inputs = np.random.rand(bs, inC, H, W).astype(np.float32)\n",
    "raw_labels = np.random.rand(bs, ouC, (H+2*padding-2)//1, (W+2*padding-2)//1).astype(np.float32)\n",
    "# weights for conv offsets.\n",
    "offset_weights = np.random.rand(18, inC, 3, 3).astype(np.float32)\n",
    "# weights for deformable convolution.\n",
    "conv_weights = np.random.rand(ouC, inC, 3, 3).astype(np.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "inputs:\n",
      "array([[[[ 0.56075788,  0.56448251,  0.38643569,  0.13775933,  0.92719644],\n",
      "         [ 0.18066591,  0.24222445,  0.29689947,  0.54874617,  0.95001829],\n",
      "         [ 0.61031544,  0.84815538,  0.27238497,  0.53376287,  0.93240666],\n",
      "         [ 0.54890364,  0.60794067,  0.1237376 ,  0.16012843,  0.82202536]]]], dtype=float32)\n",
      "\n",
      "labels:\n",
      "array([[[[ 0.68657815,  0.10542295,  0.04489666,  0.64058632,  0.52095002],\n",
      "         [ 0.73867059,  0.91901845,  0.80943078,  0.2182935 ,  0.02595145],\n",
      "         [ 0.31954384,  0.80359656,  0.53808153,  0.46827996,  0.90268624],\n",
      "         [ 0.84400773,  0.5750683 ,  0.55033565,  0.11278367,  0.47512576]]]], dtype=float32)\n",
      "\n",
      "conv weights:\n",
      "array([[[[ 0.89284414,  0.98574871,  0.94764489],\n",
      "         [ 0.69642198,  0.84854221,  0.98900223],\n",
      "         [ 0.82735974,  0.4257046 ,  0.59915102]]]], dtype=float32)\n"
     ]
    }
   ],
   "source": [
    "print('\\ninputs:')\n",
    "pprint(raw_inputs)\n",
    "print('\\nlabels:')\n",
    "pprint(raw_labels)\n",
    "print('\\nconv weights:')\n",
    "pprint(conv_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###  Set up models of PyTorch&MXNet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TestModel(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(TestModel, self).__init__()\n",
    "        self.conv_offset = nn.Conv2d(in_channels=inC, out_channels=18, kernel_size=3, padding=padding, bias=None)\n",
    "        self.deform_conv = DeformConv2D(inc=inC, outc=ouC, padding=padding)\n",
    "\n",
    "    def forward(self, x):\n",
    "        offsets = self.conv_offset(x)\n",
    "        out = self.deform_conv(x, offsets)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = TestModel()\n",
    "\n",
    "pt_inputs = Variable(torch.from_numpy(raw_inputs).cuda(), requires_grad=True)\n",
    "pt_labels = Variable(torch.from_numpy(raw_labels).cuda(), requires_grad=False)\n",
    "\n",
    "optimizer = torch.optim.SGD([{'params': model.parameters()}], lr=1e-1)\n",
    "loss_fn = torch.nn.MSELoss(reduce=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Init weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Conv2d (1, 18, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def init_weights(m):\n",
    "    if isinstance(m, torch.nn.Conv2d):\n",
    "        m.weight.data = torch.from_numpy(conv_weights)\n",
    "        if m.bias is not None:\n",
    "            m.bias.data = torch.FloatTensor(m.bias.shape[0]).zero_()\n",
    "\n",
    "def init_offsets_weights(m):\n",
    "    if isinstance(m, torch.nn.Conv2d):\n",
    "        m.weight.data = torch.from_numpy(offset_weights)\n",
    "        if m.bias is not None:\n",
    "            m.bias.data = torch.FloatTensor(m.bias.shape[0]).zero_()\n",
    "\n",
    "model.deform_conv.apply(init_weights)\n",
    "model.conv_offset.apply(init_offsets_weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "if use_gpu:\n",
    "    model.cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Set MXNet model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# trainiter\n",
    "train_iter = mx.io.NDArrayIter(raw_inputs, raw_labels, 1, shuffle=True, data_name='data', label_name='label')\n",
    "\n",
    "# # symbol\n",
    "inputs = mx.symbol.Variable('data')\n",
    "labels = mx.symbol.Variable('label')\n",
    "offsets = mx.symbol.Convolution(data=inputs, kernel=(3, 3), pad=(padding, padding), num_filter=18, name='offset', no_bias=True)\n",
    "net = mx.symbol.contrib.DeformableConvolution(data=inputs, offset=offsets, kernel=(3, 3), pad=(padding, padding), num_filter=ouC, name='deform', no_bias=True)\n",
    "outputs = mx.symbol.MakeLoss(data=mx.symbol.mean((net-labels)**2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "mod = mx.mod.Module(symbol=outputs,\n",
    "                    context=mx.gpu(),\n",
    "                    data_names=['data'],\n",
    "                    label_names=['label'])\n",
    "\n",
    "mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)\n",
    "mod.init_params(initializer=mx.initializer.Load({'deform_weight': mx.nd.array(conv_weights),\n",
    "                                                 'offset_weight': mx.nd.array(offset_weights)}))\n",
    "mod.init_optimizer(optimizer='sgd', optimizer_params=(('learning_rate', 0.1),))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      "(0 ,0 ,.,.) = \n",
      "  1.1062  2.6770  3.6377  2.7072  1.3894\n",
      "  2.7560  3.1664  3.4931  0.8367  0.5725\n",
      "  2.1425  1.1672  2.4347  1.5723  0.0000\n",
      "  1.5282  0.7295  1.6247  2.0247  1.5443\n",
      "[torch.cuda.FloatTensor of size 1x1x4x5 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "output = model(pt_inputs)\n",
    "pprint(output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MXNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[[[[ 1.10622048  2.67702818  3.63765717  2.70719433  1.38943076]\n",
      "   [ 2.75602031  3.16641665  3.49313188  0.83671159  0.57247651]\n",
      "   [ 2.14249563  1.16722107  2.43471932  1.57227027  0.        ]\n",
      "   [ 1.52822971  0.72951925  1.62470031  2.02470279  1.54425097]]]]\n",
      "<NDArray 1x1x4x5 @gpu(0)>\n"
     ]
    }
   ],
   "source": [
    "mx_inputs = mx.nd.array(raw_inputs, ctx=mx.gpu())\n",
    "conv_weights = mx.nd.array(conv_weights, ctx=mx.gpu())\n",
    "offset_weights = mx.nd.array(offset_weights, ctx=mx.gpu())\n",
    "offset = mx.ndarray.Convolution(data=mx_inputs, weight=offset_weights, kernel=(3, 3), pad=(padding, padding), num_filter=18, name='offset', no_bias=True)\n",
    "outputs = mx.ndarray.contrib.DeformableConvolution(data=mx_inputs, offset=offset, weight=conv_weights, kernel=(3, 3), pad=(padding, padding), num_filter=ouC, name='deform', no_bias=True)\n",
    "pprint(outputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(100):\n",
    "    output = model(pt_inputs)\n",
    "    loss = loss_fn(output, pt_labels)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Variable containing:\n",
      "(0 ,0 ,.,.) = \n",
      "  0.1410  0.3813  0.3941  0.7022  0.3028\n",
      "  0.4378  0.7259  0.8587  0.3047 -0.0288\n",
      "  0.5056  0.3441  0.6582  0.5480  0.0000\n",
      "  0.4662  0.2311  0.4464  0.5114  0.5383\n",
      "[torch.cuda.FloatTensor of size 1x1x4x5 (GPU 0)]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pprint(model(pt_inputs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### MXNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(100):\n",
    "    train_iter.reset()\n",
    "    for batch in train_iter:\n",
    "        # get outputs\n",
    "#         infer_outputs = mx.mod.Module(symbol=net,\n",
    "#                                      context=mx.gpu(),\n",
    "#                                      data_names=['data'])\n",
    "#         infer_outputs.bind(data_shapes=train_iter.provide_data)\n",
    "#         infer_outputs.set_params(arg_params=mod.get_params()[0], aux_params=mod.get_params()[1], allow_extra=True)\n",
    "#         outputs_value = infer_outputs.predict(train_iter)\n",
    "\n",
    "        mod.forward(batch, is_train=True)  # compute predictions\n",
    "        mod.backward()  # compute gradients\n",
    "        mod.update()  # update parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "[[[[ 0.14098766  0.38126716  0.39405173  0.70216376  0.30278912]\n",
      "   [ 0.43777964  0.72585207  0.8587358   0.30472821 -0.02884051]\n",
      "   [ 0.50558764  0.34406984  0.65824842  0.54796678  0.        ]\n",
      "   [ 0.46616155  0.23111115  0.44640523  0.51136774  0.53833312]]]]\n",
      "<NDArray 1x1x4x5 @gpu(0)>\n"
     ]
    }
   ],
   "source": [
    "mx_inputs = mx.nd.array(raw_inputs, ctx=mx.gpu())\n",
    "mx_labels = mx.nd.array(raw_labels, ctx=mx.gpu())\n",
    "conv_weights = mod.get_params()[0]['deform_weight'].as_in_context(mx.gpu())\n",
    "offset_weights = mod.get_params()[0]['offset_weight'].as_in_context(mx.gpu())\n",
    "offset = mx.ndarray.Convolution(data=mx_inputs, weight=offset_weights, kernel=(3, 3), pad=(padding, padding), num_filter=18, name='offset', no_bias=True)\n",
    "outputs = mx.ndarray.contrib.DeformableConvolution(data=mx_inputs, offset=offset, weight=conv_weights, kernel=(3, 3), pad=(padding, padding), num_filter=ouC, name='deform', no_bias=True)\n",
    "pprint(outputs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
